import torch
from torch import nn
from torch.nn import functional as F

from einops import rearrange, reduce

from typing import List, Callable, Union, Any, TypeVar, Tuple
# Annotation alias for tensors used throughout this module. It points at the
# real tensor class (rather than a TypeVar) so type checkers and tooling
# resolve the annotated signatures as plain tensor types.
Tensor = torch.Tensor


class BetaVAE(nn.Module):
    """Convolutional VAE whose decoder is conditioned on an extra vector.

    The encoder is a stack of stride-2 ``Conv2d`` blocks (one per entry of
    ``hidden_dims``) followed by two linear heads producing ``mu`` and
    ``log_var``.  The condition vector is embedded to ``latent_dim`` features
    by a small MLP and concatenated to the sampled latent, so the decoder
    consumes ``2 * latent_dim`` features.

    The linear heads are sized ``hidden_dims[-1] * 4``, i.e. they expect the
    encoder to end on a 2x2 feature map.  With stride-2 / padding-1 / kernel-3
    convolutions this means a square input of side ``2 * 2 ** len(hidden_dims)``
    (64x64 for the default five blocks).
    """

    # Counts calls to ``loss_function``.  ``self.num_iter += 1`` creates a
    # per-instance counter on first use, starting from this class default.
    num_iter = 0

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 condition_dim: int,
                 hidden_dims: List = None,
                 kl_std: Union[float, str] = 1.0,
                 beta: int = 4,
                 gamma: float = 10.,
                 max_capacity: int = 25,
                 Capacity_max_iter: float = 1e5,
                 loss_type: str = 'B',
                 **kwargs) -> None:
        """Build encoder, condition encoder and decoder.

        Args:
            in_channels: Channel count of the input (and reconstructed) image.
            latent_dim: Size of the sampled latent and of the condition
                embedding.
            condition_dim: Feature size of the raw condition vector.
            hidden_dims: Output channels of each encoder block, in encoder
                order.  Defaults to five blocks of 512 channels.  The passed
                list is copied and never modified.
            kl_std: Standard deviation of the zero-mean Gaussian prior used in
                ``loss_function``, or the string ``'zero_mean'`` to use the
                L2-norm penalty instead of a KL term.
            beta, gamma, max_capacity, Capacity_max_iter, loss_type: Stored on
                the instance; not read by any method of this class.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``hidden_dims`` is empty.
        """
        super().__init__()

        self.latent_dim = latent_dim
        self.beta = beta
        self.gamma = gamma
        self.loss_type = loss_type
        self.C_max = torch.Tensor([max_capacity])
        self.C_stop_iter = Capacity_max_iter
        self.in_channels = in_channels
        self.kl_std = kl_std

        # Work on a private copy: the decoder needs the reversed order and the
        # caller's list must not be reversed as a side effect.
        if hidden_dims is None:
            hidden_dims = [512, 512, 512, 512, 512]
        else:
            hidden_dims = list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one entry")

        # Encoder: every block halves the spatial size of an even-sized map.
        encoder_blocks = []
        prev_channels = in_channels
        for h_dim in hidden_dims:
            encoder_blocks.append(
                nn.Sequential(
                    nn.Conv2d(prev_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            prev_channels = h_dim

        self.condition_encoder = nn.Sequential(
            nn.Linear(condition_dim, latent_dim),
            nn.ReLU()
        )

        self.encoder = nn.Sequential(*encoder_blocks)
        # "* 4" == flattened 2x2 feature map at the end of the encoder.
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Decoder input is [sampled latent, condition embedding].
        self.decoder_input = nn.Linear(latent_dim * 2, hidden_dims[-1] * 4)

        decoder_dims = hidden_dims[::-1]
        # Kept in decoder (reversed) order: that is the value this attribute
        # has always held after construction.
        self.hidden_dims = decoder_dims

        decoder_blocks = [
            nn.Sequential(
                nn.ConvTranspose2d(c_in,
                                   c_out,
                                   kernel_size=3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU())
            for c_in, c_out in zip(decoder_dims, decoder_dims[1:])
        ]
        self.decoder = nn.Sequential(*decoder_blocks)

        # One more x2 upsampling, then project back to the image channels;
        # Tanh bounds the reconstruction to [-1, 1].
        last = decoder_dims[-1]
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(last,
                               last,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(last),
            nn.LeakyReLU(),
            nn.Conv2d(last, out_channels=self.in_channels,
                      kernel_size=3, padding=1),
            nn.Tanh())

    def encode(self, enc_input: Tensor) -> List[Tensor]:
        """Run the encoder and return ``[mu, log_var]``, each ``(B, latent_dim)``."""
        features = torch.flatten(self.encoder(enc_input), start_dim=1)
        return [self.fc_mu(features), self.fc_var(features)]

    def decode(self, z: Tensor) -> Tensor:
        """Decode ``z`` into an image.

        Args:
            z: ``(B, 2 * latent_dim)`` tensor: the sampled latent concatenated
                with the condition embedding.
        """
        result = self.decoder_input(z)
        # Undo the flatten of the encoder: (B, C * 4) -> (B, C, 2, 2).
        result = result.view(-1, result.shape[-1] // 4, 2, 2)
        result = self.decoder(result)
        return self.final_layer(result)

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, data: Tensor, atc: Tensor, **kwargs) -> List[Tensor]:
        """Encode ``data``, sample a latent, and decode it under condition ``atc``.

        Returns:
            ``[reconstruction, data, mu, log_var, z]`` where ``z`` is the
            sampled latent concatenated with the condition embedding.
        """
        atc_emb = self.condition_encoder(atc.float())
        mu, log_var = self.encode(data)
        z = self.reparameterize(mu, log_var)
        z = torch.cat([z, atc_emb], dim=1)
        return [self.decode(z), data, mu, log_var, z]

    def loss_function(self,
                      *args,
                      **kwargs) -> Tensor:
        """Return the weighted latent regularisation term.

        Only the regulariser is computed here; no reconstruction term is
        included.  Positional arguments follow the layout of ``forward``'s
        output: ``args[2]`` is ``mu`` and ``args[3]`` is ``log_var``.

        Keyword Args:
            M_N: Weight applied to the regulariser (required).

        Returns:
            Scalar tensor ``M_N * kl_loss``.
        """
        self.num_iter += 1
        mu = args[2]
        log_var = args[3]
        kld_weight = kwargs['M_N']

        if self.kl_std == 'zero_mean':
            # Penalise the mean L2 norm of a freshly sampled latent.
            latent = self.reparameterize(mu, log_var)
            l2_size_loss = torch.sum(torch.norm(latent, dim=-1))
            kl_loss = l2_size_loss / latent.shape[0]
        else:
            std = torch.exp(0.5 * log_var)
            prior = torch.distributions.normal.Normal(
                torch.zeros_like(mu), torch.ones_like(std) * self.kl_std)
            posterior = torch.distributions.normal.Normal(mu, std)
            kl = torch.distributions.kl.kl_divergence(posterior, prior)
            # Mean over every element (batch and latent dimensions alike).
            kl_loss = kl.mean()

        return kld_weight * kl_loss

    def sample(self,
               num_samples: int,
               **kwargs) -> Tensor:
        """Decode ``num_samples`` latents drawn from a standard normal.

        Keyword Args:
            atc: Optional condition, shaped ``(condition_dim,)``,
                ``(1, condition_dim)`` or ``(num_samples, condition_dim)``.
                A single condition is shared by all samples.  When omitted,
                the condition half of the decoder input is filled with zeros.
                NOTE(review): the decoder is not necessarily trained on an
                all-zero condition embedding; pass ``atc`` for meaningful
                conditional samples.

        Returns:
            Decoded samples on the same device as the model.
        """
        # Follow the model's own device/dtype instead of assuming CUDA.
        ref = self.decoder_input.weight
        z = torch.randn(num_samples, self.latent_dim,
                        device=ref.device, dtype=ref.dtype)

        atc = kwargs.get('atc')
        if atc is None:
            cond = torch.zeros_like(z)
        else:
            atc = torch.as_tensor(atc, device=ref.device)
            if atc.dim() == 1:
                atc = atc.unsqueeze(0)
            cond = self.condition_encoder(atc.float()).expand(num_samples, -1)

        # The decoder expects [latent, condition embedding].
        return self.decode(torch.cat([z, cond], dim=1))

    def generate(self, x: Tensor, atc: Tensor, **kwargs) -> Tensor:
        """Return only the reconstruction of ``x`` under condition ``atc``."""
        return self.forward(x, atc)[0]

    def get_latent(self, x: Tensor) -> Tensor:
        """Encode ``x`` and return one sampled latent of shape ``(B, latent_dim)``."""
        mu, log_var = self.encode(x)
        return self.reparameterize(mu, log_var)